{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f400486b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-21T13:24:28.700830400Z",
     "start_time": "2024-01-21T13:24:28.362250900Z"
    }
   },
   "outputs": [],
   "source": [
    "# Copyright (c) Meta Platforms, Inc. and affiliates."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1ae39ff",
   "metadata": {},
   "source": [
    "# Object masks from prompts with SAM"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4a4b25c",
   "metadata": {},
   "source": [
    "The Segment Anything Model (SAM) predicts object masks given prompts that indicate the desired object. The model first converts the image into an image embedding that allows high quality masks to be efficiently produced from a prompt. \n",
    "\n",
    "The `SamPredictor` class provides an easy interface to the model for prompting the model. It allows the user to first set an image using the `set_image` method, which calculates the necessary image embeddings. Then, prompts can be provided via the `predict` method to efficiently predict masks from those prompts. The model can take as input both point and box prompts, as well as masks from the previous iteration of prediction."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "644532a8",
   "metadata": {},
   "source": [
    "## Environment Set-up"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07fabfee",
   "metadata": {},
   "source": [
    "If running locally using jupyter, first install `segment_anything` in your environment using the [installation instructions](https://github.com/facebookresearch/segment-anything#installation) in the repository. If running from Google Colab, set `using_colab=True` below and run the cell. In Colab, be sure to select 'GPU' under 'Edit'->'Notebook Settings'->'Hardware accelerator'."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0be845da",
   "metadata": {},
   "source": [
    "## Set-up"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33681dd1",
   "metadata": {},
   "source": [
    "Necessary imports and helper functions for displaying points, boxes, and masks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "69b28288",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-21T13:24:54.676095300Z",
     "start_time": "2024-01-21T13:24:54.631098700Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "29bc90d5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-21T13:24:55.294302500Z",
     "start_time": "2024-01-21T13:24:55.268303500Z"
    }
   },
   "outputs": [],
   "source": [
    "def show_mask(mask, ax, random_color=False):\n",
    "    if random_color:\n",
    "        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n",
    "    else:\n",
    "        color = np.array([30/255, 144/255, 255/255, 0.6])\n",
    "    h, w = mask.shape[-2:]\n",
    "    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)\n",
    "    ax.imshow(mask_image)\n",
    "    \n",
    "def show_points(coords, labels, ax, marker_size=375):\n",
    "    pos_points = coords[labels==1]\n",
    "    neg_points = coords[labels==0]\n",
    "    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n",
    "    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   \n",
    "    \n",
    "def show_box(box, ax):\n",
    "    x0, y0 = box[0], box[1]\n",
    "    w, h = box[2] - box[0], box[3] - box[1]\n",
    "    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))    \n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23842fb2",
   "metadata": {},
   "source": [
    "## Example image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3c2e4f6b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-21T13:24:57.770891500Z",
     "start_time": "2024-01-21T13:24:57.718358500Z"
    }
   },
   "outputs": [
    {
     "ename": "error",
     "evalue": "OpenCV(4.9.0) D:\\a\\opencv-python\\opencv-python\\opencv\\modules\\imgproc\\src\\color.cpp:196: error: (-215:Assertion failed) !_src.empty() in function 'cv::cvtColor'\n",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31merror\u001B[0m                                     Traceback (most recent call last)",
      "Cell \u001B[1;32mIn[8], line 3\u001B[0m\n\u001B[0;32m      1\u001B[0m image \u001B[38;5;241m=\u001B[39m cv2\u001B[38;5;241m.\u001B[39mimread(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mimages/truck.jpg\u001B[39m\u001B[38;5;124m'\u001B[39m)\n\u001B[0;32m      2\u001B[0m \u001B[38;5;66;03m#导入图片\u001B[39;00m\n\u001B[1;32m----> 3\u001B[0m image \u001B[38;5;241m=\u001B[39m \u001B[43mcv2\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mcvtColor\u001B[49m\u001B[43m(\u001B[49m\u001B[43mimage\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcv2\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mCOLOR_BGR2RGB\u001B[49m\u001B[43m)\u001B[49m\n",
      "\u001B[1;31merror\u001B[0m: OpenCV(4.9.0) D:\\a\\opencv-python\\opencv-python\\opencv\\modules\\imgproc\\src\\color.cpp:196: error: (-215:Assertion failed) !_src.empty() in function 'cv::cvtColor'\n"
     ]
    }
   ],
   "source": [
    "image = cv2.imread('images/truck.jpg')\n",
    "#导入图片\n",
    "image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e30125fd",
   "metadata": {
    "scrolled": false,
    "ExecuteTime": {
     "start_time": "2024-01-21T13:24:32.689838Z"
    }
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,10))\n",
    "plt.imshow(image)\n",
    "plt.axis('on')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98b228b8",
   "metadata": {},
   "source": [
    "## Selecting objects with SAM"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0bb1927b",
   "metadata": {},
   "source": [
    "First, load the SAM model and predictor. Change the path below to point to the SAM checkpoint. Running on CUDA and using the default model are recommended for best results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7e28150b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-21T13:25:08.453430100Z",
     "start_time": "2024-01-21T13:25:01.572343500Z"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "from segment_anything import sam_model_registry, SamPredictor\n",
    "#选择模型\n",
    "sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
    "model_type = \"vit_h\"\n",
    "\n",
    "device = \"cuda\"\n",
    "\n",
    "sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n",
    "sam.to(device=device)\n",
    "\n",
    "predictor = SamPredictor(sam)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c925e829",
   "metadata": {},
   "source": [
    "Process the image to produce an image embedding by calling `SamPredictor.set_image`. `SamPredictor` remembers this embedding and will use it for subsequent mask prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d95d48dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.set_image(image)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8fc7a46",
   "metadata": {},
   "source": [
    "To select the truck, choose a point on it. Points are input to the model in (x,y) format and come with labels 1 (foreground point) or 0 (background point). Multiple points can be input; here we use only one. The chosen point will be shown as a star on the image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c69570c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#在图片中画出星星，其实就相当于鼠标放置的位置，模型会以这个位置为远点进行分割（影响最终效果）\n",
    "input_point = np.array([[100, 200]])\n",
    "#1是绿色的，0是红色的，其中绿色是物体点，红色是背景点，就相当于绿色是正相关，我们想要，红色是负相关，我们不想要\n",
    "input_label = np.array([1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a91ba973",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,10))\n",
    "plt.imshow(image)\n",
    "show_points(input_point, input_label, plt.gca())\n",
    "plt.axis('on')\n",
    "plt.show()  "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c765e952",
   "metadata": {},
   "source": [
    "Predict with `SamPredictor.predict`. The model returns masks, quality predictions for those masks, and low resolution mask logits that can be passed to the next iteration of prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5373fd68",
   "metadata": {},
   "outputs": [],
   "source": [
    "masks, scores, logits = predictor.predict(\n",
    "    point_coords=input_point,\n",
    "    point_labels=input_label,\n",
    "    multimask_output=True,\n",
    "    #这里注意，TRUE的话会返回三个mask，FALSE的话只会返回一个mask，感觉和热值是一个类似的东西，个人感觉对于临摹两颗的东西建议选择FALSE，这样人不用动脑，但也有代价，得具体分析\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7f0e938",
   "metadata": {},
   "source": [
    "With `multimask_output=True` (the default setting), SAM outputs 3 masks, where `scores` gives the model's own estimation of the quality of these masks. This setting is intended for ambiguous input prompts, and helps the model disambiguate different objects consistent with the prompt. When `False`, it will return a single mask. For ambiguous prompts such as a single point, it is recommended to use `multimask_output=True` even if only a single mask is desired; the best single mask can be chosen by picking the one with the highest score returned in `scores`. This will often result in a better mask."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47821187",
   "metadata": {},
   "outputs": [],
   "source": [
    "masks.shape  # (number_of_masks) x H x W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9c227a6",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "for i, (mask, score) in enumerate(zip(masks, scores)):\n",
    "    plt.figure(figsize=(10,10))\n",
    "    plt.imshow(image)\n",
    "    show_mask(mask, plt.gca())\n",
    "    show_points(input_point, input_label, plt.gca())\n",
    "    plt.title(f\"Mask {i+1}, Score: {score:.3f}\", fontsize=18)\n",
    "    plt.axis('off')\n",
    "    plt.show()  \n",
    "  #遮罩是二维的图片，所以这里的mask是二维的，但是这里的mask是三维的，所以需要遍历一下"
   ]
  },
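  {
   "cell_type": "markdown",
   "id": "4f1c9a2e",
   "metadata": {},
   "source": [
    "As noted above, when only one mask is wanted from an ambiguous prompt, the highest-scoring candidate can be kept. The cell below is a minimal sketch of that selection, assuming `masks` and `scores` come from the `predict` call above; the names `best_idx`, `best_mask`, and `best_score` are illustrative and not part of the `segment_anything` API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d3b7e05",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Index of the candidate that the model itself rates highest\n",
    "best_idx = int(np.argmax(scores))\n",
    "best_mask = masks[best_idx]  # H x W mask\n",
    "best_score = float(scores[best_idx])\n",
    "\n",
    "plt.figure(figsize=(10,10))\n",
    "plt.imshow(image)\n",
    "show_mask(best_mask, plt.gca())\n",
    "show_points(input_point, input_label, plt.gca())\n",
    "plt.title(f\"Best mask (index {best_idx}), Score: {best_score:.3f}\", fontsize=18)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },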
  {
   "cell_type": "markdown",
   "id": "3fa31f7c",
   "metadata": {},
   "source": [
    "## Specifying a specific object with additional points"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88d6d29a",
   "metadata": {},
   "source": [
    "The single input point is ambiguous, and the model has returned multiple objects consistent with it. To obtain a single object, multiple points can be provided. If available, a mask from a previous iteration can also be supplied to the model to aid in prediction. When specifying a single object with multiple prompts, a single mask can be requested by setting `multimask_output=False`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6923b94",
   "metadata": {},
   "outputs": [],
   "source": [
    "#用附加点指定对象，就相当于在图片中画出两个或多个点，同时指定一个对象，这样得到的结果更为精确\n",
    "input_point = np.array([[500, 375], [1125, 625]])\n",
    "input_label = np.array([1, 1])\n",
    "\n",
    "mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d98f96a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "masks, _, _ = predictor.predict(\n",
    "    point_coords=input_point,\n",
    "    point_labels=input_label,\n",
    "    mask_input=mask_input[None, :, :],\n",
    "    multimask_output=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ce8b82f",
   "metadata": {},
   "outputs": [],
   "source": [
    "masks.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e06d5c8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#这地方也可以搞成向像上一种方法那样进行循环把所有图片都load下来，当然，得将multimask_output=true\n",
    "plt.figure(figsize=(10,10))\n",
    "plt.imshow(image)\n",
    "show_mask(masks, plt.gca())\n",
    "show_points(input_point, input_label, plt.gca())\n",
    "plt.axis('off')\n",
    "plt.show() "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c93e2087",
   "metadata": {},
   "source": [
    "To exclude the car and specify just the window, a background point (with label 0, here shown in red) can be supplied."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a196f68",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_point = np.array([[500, 375], [1125, 625]])\n",
    "input_label = np.array([1, 0])\n",
    "\n",
    "mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81a52282",
   "metadata": {},
   "outputs": [],
   "source": [
    "masks, _, _ = predictor.predict(\n",
    "    point_coords=input_point,\n",
    "    point_labels=input_label,\n",
    "    mask_input=mask_input[None, :, :],\n",
    "    multimask_output=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfca709f",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "plt.imshow(image)\n",
    "show_mask(masks, plt.gca())\n",
    "show_points(input_point, input_label, plt.gca())\n",
    "plt.axis('off')\n",
    "plt.show() "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41e2d5a9",
   "metadata": {},
   "source": [
    "## Specifying a specific object with a box"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d61ca7ac",
   "metadata": {},
   "source": [
    "The model can also take a box as input, provided in xyxy format."
   ]
  },
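  {
   "cell_type": "markdown",
   "id": "2a6f0c81",
   "metadata": {},
   "source": [
    "In xyxy format a box is `[x_min, y_min, x_max, y_max]` in pixel coordinates, i.e. its top-left and bottom-right corners. Boxes are sometimes stored as `[x, y, width, height]` instead; the sketch below shows one way to convert them. `xywh_to_xyxy` is a hypothetical helper written for this notebook, not part of `segment_anything`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e58b13d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def xywh_to_xyxy(box_xywh):\n",
    "    # Hypothetical helper: [x, y, width, height] -> [x_min, y_min, x_max, y_max]\n",
    "    x, y, w, h = box_xywh\n",
    "    return np.array([x, y, x + w, y + h])\n",
    "\n",
    "# The box used in this section, given as top-left corner plus size\n",
    "xywh_to_xyxy([425, 600, 275, 275])  # -> array([425, 600, 700, 875])"
   ]
  },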
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ea92a7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#用一个框来选择一个对象\n",
    "#坐标分别为左上和右下角的坐标\n",
    "input_box = np.array([425, 600, 700, 875])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b35a8814",
   "metadata": {},
   "outputs": [],
   "source": [
    "masks, _, _ = predictor.predict(\n",
    "    point_coords=None,\n",
    "    point_labels=None,\n",
    "    box=input_box[None, :],\n",
    "    multimask_output=False,\n",
    "    #也是只要一个\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "984b79c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "plt.imshow(image)\n",
    "show_mask(masks[0], plt.gca())\n",
    "show_box(input_box, plt.gca())\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1ed9f0a",
   "metadata": {},
   "source": [
    "## Combining points and boxes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8455d1c5",
   "metadata": {},
   "source": [
    "Points and boxes may be combined, just by including both types of prompts to the predictor. Here this can be used to select just the trucks's tire, instead of the entire wheel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90e2e547",
   "metadata": {},
   "outputs": [],
   "source": [
    "#小孩子才选择，我们全都要，结合点和框来进行预测\n",
    "#应用场景：只选择轮胎，而不要轮胎中的钢圈\n",
    "#框\n",
    "input_box = np.array([425, 600, 700, 875])\n",
    "#点\n",
    "input_point = np.array([[575, 750]])\n",
    "#设置为负值点\n",
    "input_label = np.array([0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6956d8c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "masks, _, _ = predictor.predict(\n",
    "    point_coords=input_point,\n",
    "    point_labels=input_label,\n",
    "    box=input_box,\n",
    "    multimask_output=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e13088a",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "plt.imshow(image)\n",
    "show_mask(masks[0], plt.gca())\n",
    "show_box(input_box, plt.gca())\n",
    "show_points(input_point, input_label, plt.gca())\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45ddbca3",
   "metadata": {},
   "source": [
    "## Batched prompt inputs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df6f18a0",
   "metadata": {},
   "source": [
    "SamPredictor can take multiple input prompts for the same image, using `predict_torch` method. This method assumes input points are already torch tensors and have already been transformed to the input frame. For example, imagine we have several box outputs from an object detector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a06681b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#分批提示输入\n",
    "input_boxes = torch.tensor([\n",
    "    [75, 275, 1725, 850],\n",
    "    [425, 600, 700, 875],\n",
    "    [1375, 550, 1650, 800],\n",
    "    [1240, 675, 1400, 750],\n",
    "], device=predictor.device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf957d16",
   "metadata": {},
   "source": [
    "Transform the boxes to the input frame, then predict masks. `SamPredictor` stores the necessary transform as the `transform` field for easy access, though it can also be instantiated directly for use in e.g. a dataloader (see `segment_anything.utils.transforms`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "117521a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])\n",
    "masks, _, _ = predictor.predict_torch(\n",
    "    point_coords=None,\n",
    "    point_labels=None,\n",
    "    boxes=transformed_boxes,\n",
    "    multimask_output=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a8f5d49",
   "metadata": {},
   "outputs": [],
   "source": [
    "masks.shape  # (batch_size) x (num_predicted_masks_per_input) x H x W\n",
    "#同时预测4个，每次只有一个mask,所以这里是4*1*H*W,h,w是图片的高和宽"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c00c3681",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "plt.imshow(image)\n",
    "for mask in masks:\n",
    "    show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)\n",
    "for box in input_boxes:\n",
    "    show_box(box.cpu().numpy(), plt.gca())\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
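  {
   "cell_type": "markdown",
   "id": "71c4ae9f",
   "metadata": {},
   "source": [
    "For intuition, the box transform used above amounts to rescaling coordinates so that the image's longest side matches the model's input size. The cell below is an illustrative NumPy re-implementation under that assumption, compared against the library result; `resize_boxes_longest_side` is a hypothetical name, and `predictor.transform` remains the thing to use in practice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c09d52b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def resize_boxes_longest_side(boxes_xyxy, original_size, target_length):\n",
    "    # Hypothetical helper for illustration; not part of segment_anything\n",
    "    old_h, old_w = original_size\n",
    "    scale = target_length / max(old_h, old_w)\n",
    "    # Shape of the resized image, rounded to the nearest pixel\n",
    "    new_h, new_w = int(old_h * scale + 0.5), int(old_w * scale + 0.5)\n",
    "    # Copy as float and view each box as two (x, y) corner points\n",
    "    corners = np.array(boxes_xyxy, dtype=np.float64).reshape(-1, 2, 2)\n",
    "    corners[..., 0] *= new_w / old_w  # x coordinates\n",
    "    corners[..., 1] *= new_h / old_h  # y coordinates\n",
    "    return corners.reshape(-1, 4)\n",
    "\n",
    "manual_boxes = resize_boxes_longest_side(\n",
    "    input_boxes.cpu().numpy(), image.shape[:2], sam.image_encoder.img_size\n",
    ")\n",
    "# Largest absolute difference from the library transform computed earlier\n",
    "print(np.abs(manual_boxes - transformed_boxes.cpu().numpy()).max())"
   ]
  },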
  {
   "cell_type": "markdown",
   "id": "8bea70c0",
   "metadata": {},
   "source": [
    "## End-to-end batched inference"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89c3ba52",
   "metadata": {},
   "source": [
    "If all prompts are available in advance, it is possible to run SAM directly in an end-to-end fashion. This also allows batching over images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45c01ae4",
   "metadata": {},
   "outputs": [],
   "source": [
    "#端到端的分批处理\n",
    "\n",
    "#这里的image1和image2是两张图片，image1_boxes和image2_boxes是两张图片的框\n",
    "#相当于进行批量处理，我觉得这个地方可以结合大模型搞一搞\n",
    "image1 = image  # truck.jpg from above\n",
    "image1_boxes = torch.tensor([\n",
    "    [75, 275, 1725, 850],\n",
    "    [425, 600, 700, 875],\n",
    "    [1375, 550, 1650, 800],\n",
    "    [1240, 675, 1400, 750],\n",
    "], device=sam.device)\n",
    "\n",
    "image2 = cv2.imread('images/groceries.jpg')\n",
    "image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)\n",
    "image2_boxes = torch.tensor([\n",
    "    [450, 170, 520, 350],\n",
    "    [350, 190, 450, 350],\n",
    "    [500, 170, 580, 350],\n",
    "    [580, 170, 640, 350],\n",
    "], device=sam.device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce56c57d",
   "metadata": {},
   "source": [
    "Both images and prompts are input as PyTorch tensors that are already transformed to the correct frame. Inputs are packaged as a list over images, which each element is a dict that takes the following keys:\n",
    "* `image`: The input image as a PyTorch tensor in CHW format.\n",
    "* `original_size`: The size of the image before transforming for input to SAM, in (H, W) format.\n",
    "* `point_coords`: Batched coordinates of point prompts.\n",
    "* `point_labels`: Batched labels of point prompts.\n",
    "* `boxes`: Batched input boxes.\n",
    "* `mask_inputs`: Batched input masks.\n",
    "\n",
    "If a prompt is not present, the key can be excluded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79f908ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "from segment_anything.utils.transforms import ResizeLongestSide\n",
    "resize_transform = ResizeLongestSide(sam.image_encoder.img_size)\n",
    "\n",
    "def prepare_image(image, transform, device):\n",
    "    image = transform.apply_image(image)\n",
    "    image = torch.as_tensor(image, device=device.device) \n",
    "    return image.permute(2, 0, 1).contiguous()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23f63723",
   "metadata": {},
   "outputs": [],
   "source": [
    "batched_input = [\n",
    "     {\n",
    "         'image': prepare_image(image1, resize_transform, sam),\n",
    "         'boxes': resize_transform.apply_boxes_torch(image1_boxes, image1.shape[:2]),\n",
    "         'original_size': image1.shape[:2]\n",
    "     },\n",
    "     {\n",
    "         'image': prepare_image(image2, resize_transform, sam),\n",
    "         'boxes': resize_transform.apply_boxes_torch(image2_boxes, image2.shape[:2]),\n",
    "         'original_size': image2.shape[:2]\n",
    "     }\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fbeb831",
   "metadata": {},
   "source": [
    "Run the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3b311b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "batched_output = sam(batched_input, multimask_output=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27bb50fd",
   "metadata": {},
   "source": [
    "The output is a list over results for each input image, where list elements are dictionaries with the following keys:\n",
    "* `masks`: A batched torch tensor of predicted binary masks, the size of the original image.\n",
    "* `iou_predictions`: The model's prediction of the quality for each mask.\n",
    "* `low_res_logits`: Low res logits for each mask, which can be passed back to the model as mask input on a later iteration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb3dba0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "batched_output[0].keys()"
   ]
  },
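  {
   "cell_type": "markdown",
   "id": "38fa6d14",
   "metadata": {},
   "source": [
    "To see how these outputs are laid out, their shapes can be printed per image. A small sketch, assuming `batched_output` from the cell above; with four boxes per image and `multimask_output=False`, each `masks` tensor should have a leading shape of 4 x 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad27e0c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the shape of every output tensor for each input image\n",
    "for i, output in enumerate(batched_output):\n",
    "    print(f\"image {i}:\")\n",
    "    for key in ('masks', 'iou_predictions', 'low_res_logits'):\n",
    "        print(f\"  {key}: {tuple(output[key].shape)}\")"
   ]
  },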
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1108f48",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 2, figsize=(20, 20))\n",
    "\n",
    "ax[0].imshow(image1)\n",
    "for mask in batched_output[0]['masks']:\n",
    "    show_mask(mask.cpu().numpy(), ax[0], random_color=True)\n",
    "for box in image1_boxes:\n",
    "    show_box(box.cpu().numpy(), ax[0])\n",
    "ax[0].axis('off')\n",
    "\n",
    "ax[1].imshow(image2)\n",
    "for mask in batched_output[1]['masks']:\n",
    "    show_mask(mask.cpu().numpy(), ax[1], random_color=True)\n",
    "for box in image2_boxes:\n",
    "    show_box(box.cpu().numpy(), ax[1])\n",
    "ax[1].axis('off')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
